from gxl_ai_utils.utils import utils_file
import torch
# Extract the encoder sub-module weights from a full-model checkpoint and save
# them as a standalone state dict whose keys no longer carry the "encoder."
# prefix, so the result maps directly onto the encoder's own parameter names.
raw_ckpt_path = "/mnt/apdcephfs_sgfd/share_303841515/Tealab/user/xuelonggeng/ckpt/fsq_from_stage1_step_40499_multi_node33_120W_stage3_causal/step_28499.pt/pytorch_model.bin"
output_encoder_ckpt_path = "/mnt/apdcephfs_sgfd/share_303841515/Tealab/user/xuelonggeng/ckpt/fsq_from_stage1_step_40499_multi_node33_120W_stage3_causal/step_28499_only_encoder.pt"
ENCODER_PREFIX = "encoder."


def extract_prefixed_params(param_dict, prefix=ENCODER_PREFIX):
    """Return the entries of *param_dict* whose key starts with *prefix*.

    The prefix is stripped from each returned key; values are passed through
    unchanged and the original insertion order of matching entries is kept.

    Args:
        param_dict: Mapping of parameter name -> value (e.g. a loaded state dict).
        prefix: Key prefix selecting the sub-module to extract.

    Returns:
        A new dict containing only the matching entries, with *prefix* removed.
    """
    return {
        key[len(prefix):]: value
        for key, value in param_dict.items()
        if key.startswith(prefix)
    }


# map_location="cpu" lets the checkpoint be loaded without the GPU(s) it was
# saved from.
param_dict = torch.load(raw_ckpt_path, map_location="cpu")
encoder_param_dict = extract_prefixed_params(param_dict)
if not encoder_param_dict:
    # Fail loudly instead of silently writing an empty "encoder" checkpoint,
    # which would only surface later as a confusing load failure.
    raise ValueError(
        f"No parameters with prefix {ENCODER_PREFIX!r} found in {raw_ckpt_path}"
    )
torch.save(encoder_param_dict, output_encoder_ckpt_path)